'''
Author: duliang thinktanker@163.com
Date: 2023-12-17 13:05:57
LastEditors: duliang thinktanker@163.com
LastEditTime: 2025-09-29 11:26:01
# FilePath: 
Description: 
'''
from paddleocr import PaddleOCR
import paddle
import time
import uvicorn
from fastapi import FastAPI, Form, Body, File, Request, Response, HTTPException, status
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse, JSONResponse
# from jinja2 import Environment, FileSystemLoader
from fastapi.templating import Jinja2Templates
import json
# from fastapi.responses import JSONResponse
# from fastapi.routing import APIRoute
from threading import Thread
# import os
import base64
# from multiprocessing import Process
# import hashlib
# from functools import lru_cache
from mqttser import mqtt_run
from getipaddress import *
from getsq import *

MAX_G = 1.5
# Reserved-GPU-memory threshold in bytes: MAX_G GiB (currently 1.5 GiB).
# ocr_box synchronises and empties the CUDA cache once paddle reports more
# reserved memory than this.
MAX_MEMORRY_SIZE = 1024 * 1024 * 1024 * MAX_G

# Optional Paddle memory / CUDA flags, kept disabled for reference. Enabling
# them requires `import os`.
# # 0.0: free tensor garbage as soon as it is unused; a larger value (GB)
# # waits until that much garbage has accumulated.
# os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
# # Fast eager deletion: free GPU memory without waiting for CUDA kernels to finish.
# os.environ['FLAGS_fast_eager_deletion_mode'] = "True"
# os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
# # Fraction of GPU memory to pre-allocate (0.5 = 50%).
# os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = "0.5"
# # ------ CUDA device selection (comment out when not using a GPU) ------ #
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# The IP allowlist is read from ./allowed_ips.json on every request (check_ip).

print("载入文本识别")
paddle.set_device("gpu:0")

# Model directories (CLI equivalents: --cls_model_dir / --det_model_dir /
# --rec_model_dir). Forward slashes are used throughout: they work on Windows
# as well as Linux, whereas backslash-only paths are Windows-specific.
# Angle-classification model.
cls_model_dir = "./whl/cls/ch/ch_ppocr_mobile_v2.0_cls_infer/"
# Text-detection model. Alternative server variant:
# "D:/paddle_serving_files/det/ch/ch_PP-OCRv4_det_server_infer/"
det_model_dir = "./whl/det/ch/ch_PP-OCRv4_det_infer"
# Text-recognition model. Alternative server variant:
# "D:/paddle_serving_files/rec/ch/ch_PP-OCRv4_rec_server_infer/"
rec_model_dir = "./whl/rec/ch/ch_PP-OCRv4_rec_infer"

use_angle_cls = True  # load the angle-classification model
use_space_char = True  # recognise spaces; NOTE: not passed to PaddleOCR below
use_gpu = True  # run PaddleOCR on the GPU (paddle.set_device above is fixed to gpu:0)
cls = False  # classify at inference time; NOTE: unused, callers pass cls=False directly
# Shared PaddleOCR pipeline used by ocr_box and capture_image.
my_ocr = PaddleOCR(
    cls_model_dir=cls_model_dir,
    det_model_dir=det_model_dir,
    rec_model_dir=rec_model_dir,
    use_angle_cls=use_angle_cls,
    # Disabled options kept for reference:
    # det_max_side_len=680,
    # det_limit_side_len=680,
    # max_text_length=25,
    # image_shape='320,320',
    # rec_image_shape='320,320',
    # rec_batch_num=10,
    # lang='ch',
    # use_tensorrt=True,
    # use_fp16=True,
    # ocr_version='PP-OCRv4',
    use_gpu=use_gpu)
# Jinja2 templates are loaded from ./templates.
templates = Jinja2Templates(directory="templates")


def ocr_box(data):
    """Run PaddleOCR on one encoded image and split the output into texts and boxes.

    The raw bytes are first written to ``result.jpg`` in the working directory
    (overwritten on every call), leaving the last request's image on disk.
    After inference the reserved GPU memory is printed; if it exceeds
    ``MAX_MEMORRY_SIZE`` the CUDA cache is synchronised and emptied.

    Args:
        data: Encoded image bytes as received from the client.

    Returns:
        A ``(texts, boxes, results)`` tuple. ``texts`` holds the recognised
        strings and ``boxes`` the matching detection boxes (same order),
        gathered from every entry of ``results``. ``results`` is the raw value
        returned by ``my_ocr.ocr``. When that value is empty, ``([], [], [])``
        is returned.
    """
    with open('result.jpg', 'wb') as f:
        f.write(data)
    results = my_ocr.ocr(data, cls=False)

    reserved = paddle.device.cuda.memory_reserved()
    print(f"识别后显存（阈值：{MAX_G}G）：\033[33m", reserved // 1024 // 1024, 'M\033[0m')
    if reserved > MAX_MEMORRY_SIZE:
        paddle.device.cuda.synchronize(0)
        paddle.device.cuda.empty_cache()
        print("清理后：\033[33m",
              paddle.device.cuda.memory_reserved() // 1024 // 1024, 'M\033[0m')

    if not results:
        return [], [], []

    send_content = []
    send_box = []
    # Presumably one entry per input page; an entry can be falsy (e.g. None)
    # when nothing was detected, so such entries are skipped.
    # Previously the function returned inside this loop, after the first entry.
    for page in results:
        if not page:
            continue
        for line in page:
            # Each line is indexed as [box, ..., (text, score)].
            send_box.append(line[0])
            send_content.append(str(line[-1][0]))
    return send_content, send_box, results


def save_ip(ip, ban, db_path=r'D:\myGitee\control-net\data\users.db'):
    """Record one request from ``ip`` in the ``ip`` table of a SQLite database.

    If a row for ``ip`` exists, its ``count`` is incremented and ``last_time``
    and ``ban`` are refreshed. Otherwise a new row is inserted with
    ``count`` = 1, ``time``/``last_time`` set to now, and ``address`` taken from
    ``get_ip_address([ip])``.

    Any error is printed and swallowed. The function is run on a
    fire-and-forget thread per request, and must not break request handling.

    Args:
        ip: Client IP string. It may come from the ``X-Forwarded-For`` header,
            so it is untrusted and is only ever bound as a query parameter.
        ban: Value written to the ``ban`` column.
        db_path: Path of the SQLite database file.
    """
    import sqlite3

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        now = int(time.time())
        cur = conn.cursor()
        # Increment inside SQL so concurrent request threads cannot lose
        # updates (the old SELECT-then-UPDATE raced).
        cur.execute(
            'UPDATE "ip" SET "count" = "count" + 1, "last_time" = ?, "ban" = ? '
            'WHERE "ip" = ?',
            (now, ban, ip),
        )
        if cur.rowcount == 0:
            ip_address = get_ip_address([ip])
            if ip_address:
                ip_address = ip_address[0]
            # str() keeps the stored text identical to the old f-string form
            # (e.g. "None" when no address was found).
            cur.execute(
                'INSERT INTO "ip" ("ip", "count", "time", "last_time", "ban", "address") '
                'VALUES (?, 1, ?, ?, ?, ?)',
                (ip, now, now, ban, str(ip_address)),
            )
        conn.commit()
    except Exception as e:
        print(e)
    finally:
        # conn stays None when connect() itself failed.
        if conn is not None:
            conn.close()


app = FastAPI()
# Open CORS policy: every origin, HTTP method and request header is accepted.
# Credentials (cookies) are not allowed, as required when the origin list is
# the wildcard ["*"]. expose_headers and max_age keep the middleware defaults.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=False,
)


@app.middleware("http")
async def check_ip(request: Request, call_next):
    """Log the client IP, record it via ``save_ip`` and enforce the IP allowlist.

    The client IP is the first entry of ``X-Forwarded-For`` when that header
    is present, otherwise the socket peer address. Every request is printed
    and passed to ``save_ip`` on a new thread with ``ban=1``. The allowlist is
    re-read from ``./allowed_ips.json`` (key ``"allowed_ips"``) on every
    request, so edits apply without a restart. IPs not on the list receive a
    403 JSON response; others continue to the route handler.

    NOTE(review): ``X-Forwarded-For`` is client-controlled. Unless a trusted
    reverse proxy always overwrites it, any caller can put an allowed IP in
    this header and bypass the allowlist. Consider honouring the header only
    when ``request.client.host`` is a known proxy.
    NOTE(review): ``ban`` is recorded as 1 for allowed requests too — confirm
    this is intended.
    """
    x_forwarded_for = request.headers.get("X-Forwarded-For")
    if x_forwarded_for:
        # "client, proxy1, proxy2" -> "client"; strip optional spaces around it.
        client_ip = x_forwarded_for.split(",")[0].strip()
    else:
        # Direct peer (may itself be a proxy). request.client can be None for
        # some transports, in which case the IP is treated as "".
        client_ip = request.client.host if request.client else ""

    print(
        f"\033[33m{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))} Client IP: {client_ip}\033[0m"
    )
    Thread(target=save_ip, args=(client_ip, 1)).start()
    with open('./allowed_ips.json', "r", encoding="utf-8") as f:
        allowed_ips = json.load(f)['allowed_ips']
    if client_ip not in allowed_ips:
        return JSONResponse(
            status_code=403,
            content={"detail": f"Access from IP {client_ip} is not allowed."},
        )
    return await call_next(request)


@app.post("/ocr")
async def read_items(imgbin: bytes = File(...)):
    """OCR an uploaded image and return ``(texts, boxes)`` from ``ocr_box``.

    The raw PaddleOCR output (third element of ``ocr_box``'s result) is dropped.
    """
    texts, boxes, _raw = ocr_box(imgbin)
    return texts, boxes


@app.post("/ocrsajj")
async def read_items(imgbin: bytes = File(...)):
    """OCR an uploaded image and return only the raw PaddleOCR result.

    Labelled 'Shui'an Jiangjun' in the original code. The function name
    duplicates the /ocr handler; both routes are registered at decoration
    time, so each still works.
    """
    *_, raw_results = ocr_box(imgbin)
    return raw_results


@app.get("/ocr")
async def get_ocr():
    """GET /ocr always answers with the JSON string "ok" (usable as a health check)."""
    reply = 'ok'
    return reply


@app.get("/phone", response_class=HTMLResponse)
async def phone():
    """Render ``phone.html`` from the templates directory with a fixed title and message."""
    context = {"title": "我的网页", "message": "欢迎来到我的网站!"}
    page = templates.get_template("phone.html").render(**context)
    return HTMLResponse(status_code=200, content=page)


@app.post("/capture")
async def capture_image(image_data: dict):
    """Decode a base64 image from a JSON body and run OCR on it.

    Body keys: ``imageBase64`` is either a data URL such as
    ``"data:image/png;base64,<payload>"`` or a bare base64 payload. ``uuid``
    is optional and is only printed.

    Returns:
        ``{"status": "success", "message": str(results)}`` where ``results``
        is the raw PaddleOCR output.

    Raises:
        HTTPException(400): ``imageBase64`` is missing, not a string, or not
            valid base64.
        HTTPException(500): OCR itself failed.
    """
    image_base64 = image_data.get('imageBase64')
    uuid = image_data.get('uuid')
    print(uuid)
    if not image_base64:
        raise HTTPException(status_code=400, detail="No image data received")
    if not isinstance(image_base64, str):
        raise HTTPException(status_code=400, detail="imageBase64 must be a string")
    # Drop an optional "data:...;base64," prefix. The former split(',')[1]
    # raised IndexError (reported as a 500) when no prefix was sent.
    payload = image_base64.split(',', 1)[-1]
    try:
        image_bytes = base64.b64decode(payload)
    except ValueError as e:  # binascii.Error subclasses ValueError
        raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {e}") from e
    try:
        results = my_ocr.ocr(image_bytes, cls=False)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "message": f"{results}"}


@app.get("/sq")
async def sq(gcname: str):
    """Return ``{"sq": getnewsq(gcname)}`` for the requested ``gcname``."""
    return {"sq": getnewsq(gcname)}


@app.get("/qs")
async def qs(gcname: str):
    """Return the bytes produced by ``getqs(gcname)`` as an ``image/png`` response."""
    return Response(media_type="image/png", content=getqs(gcname))


@app.get("/wx")
async def wx(gcname: str, gctype: int):
    """Return the image bytes produced by ``getwx(gcname, gctype)``.

    The response was previously labelled ``image/jpg``, which is not a
    registered MIME type, while the variable was named ``*_png_bytes``. The
    Content-Type is now taken from the data's magic number: ``image/png`` for
    PNG data, otherwise ``image/jpeg``.
    """
    wx_bytes = getwx(gcname, gctype)
    media_type = "image/jpeg"
    if isinstance(wx_bytes, (bytes, bytearray)) and wx_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        media_type = "image/png"
    return Response(content=wx_bytes, media_type=media_type)


@app.get('/gongkong')
async def gongkong():
    """Placeholder endpoint that always returns ``{"qs": "ok"}``.

    An image response built from a ``getgk(gcname)`` helper was sketched
    here but is not implemented.
    """
    payload = {"qs": 'ok'}
    return payload


def auto_empty_cache():
    """Every 10 s, empty the CUDA cache when reserved GPU memory exceeds 3000 MiB.

    Loops forever, so it is meant to be started on a background thread.
    """
    print("自动回收显存开启")
    print("每10s施放显存")
    threshold_mb = 3000
    while True:
        reserved_mb = paddle.device.cuda.memory_reserved() // 1024 // 1024
        if reserved_mb > threshold_mb:
            print('10S清理已用显存（阈值：3G）：', reserved_mb, 'M')
            paddle.device.cuda.empty_cache()
        time.sleep(10)


if __name__ == "__main__":
    # Optional background GPU cache cleaner (disabled):
    # Thread(target=auto_empty_cache, daemon=True).start()

    # daemon=True: the MQTT thread does not keep the process alive after
    # uvicorn returns.
    th = Thread(target=mqtt_run, args=('文字识别', [("software", 0)]), daemon=True)
    th.start()
    # Pass the app object instead of the import string "localORC:app". With a
    # string, uvicorn imports the module again under the name "localORC"
    # (presumably this file), re-running all module-level code, including
    # loading the PaddleOCR models onto the GPU a second time. The string form
    # is only required for reload=True or workers>1, neither of which is used.
    uvicorn.run(app, host='0.0.0.0', port=28211)
